import joblib
from pathlib import Path
from omegaconf import OmegaConf
import pandas as pd
import shap
from sklearn.metrics import classification_report
# --- Experiment / data configuration ---------------------------------------
EXPERIMENT_ROOT = "../../experiments/rf_features_only"
RESULTS = Path(EXPERIMENT_ROOT) / "results.yaml"
# Results file of the experiment; not referenced by the cells shown here.
results = OmegaConf.load(RESULTS)
# Set to True only for experiments that saved a fitted scaler next to the model.
HAS_SCALER = False
DATA_ROOT = "../../data/prepared/"
MODEL = Path(EXPERIMENT_ROOT) / "model.pkl"
SCALER = Path(EXPERIMENT_ROOT) / "scaler.pkl"
VAL_DATA = Path(DATA_ROOT) / "val_features.pkl"
TEST_DATA = Path(DATA_ROOT) / "test_features.pkl"
# Columns rescaled when HAS_SCALER is True.
# NOTE(review): the prepared frames displayed below have 49 *named* feature
# columns, not integer columns 0..767, so these look like embedding columns
# from another experiment -- confirm before enabling HAS_SCALER, otherwise the
# column selection will raise a KeyError.
VEC_COLS = list(range(768))


def _load_split(path):
    """Load a prepared feature pickle and split off the target column.

    Args:
        path: Location of a pickled DataFrame holding the model features plus
            a ``retweet_label`` target column and an ``id_str`` column.

    Returns:
        ``(features, labels)``: ``features`` keeps every other column in its
        original order; ``labels`` is the ``retweet_label`` Series.
    """
    # read_pickle unpickles arbitrary objects -- only point it at trusted files.
    frame = pd.read_pickle(path)
    labels = frame.pop("retweet_label")
    # id_str is a row identifier, not a model input.
    frame.drop(columns=["id_str"], inplace=True)
    return frame, labels


val_df, val_df_labels = _load_split(VAL_DATA)
test_df, test_df_labels = _load_split(TEST_DATA)

# joblib.load also unpickles -- trusted experiment artifacts only.
model = joblib.load(MODEL)
if HAS_SCALER:
    scaler = joblib.load(SCALER)
    # Apply the saved scaler to the vector columns of both evaluation splits
    # (the loop mutates val_df / test_df in place).
    for split_df in (val_df, test_df):
        split_df[VEC_COLS] = scaler.transform(split_df[VEC_COLS].values)
val_df.head()
| entities.urls | entities.media | user_in_net | has_covid_keyword | tweets_keywords_3_in_degree | tweets_keywords_3_out_degree | tweets_keywords_3_in_strength | tweets_keywords_3_out_strength | tweets_keywords_3_eigenvector_in | tweets_keywords_3_eigenvector_out | ... | users_reply_clustering | user.followers_isna | users_mention_isna | following_users_isna | users_reply_isna | log1p_num_hashtags | log1p_followers_count | log1p_friends_count | log1p_statuses_count | log1p_num_mentioned | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 173185 | 0 | 1 | 0 | 0 | -0.647823 | -0.510160 | -0.967028 | 0.944173 | -0.180504 | -0.054619 | ... | -0.797732 | 0 | 0.595178 | 1 | 1.481133 | 1.098612 | 0.762279 | 1.036245 | -0.241881 | 0.0 |
| 173186 | 0 | 1 | 0 | 0 | -0.647823 | -0.651175 | -0.967028 | -0.954585 | -0.180504 | -0.054619 | ... | -0.797732 | 0 | 0.595178 | 1 | 1.481133 | 0.693147 | 0.762279 | 1.036245 | -0.241881 | 0.0 |
| 173187 | 0 | 0 | 0 | 0 | -0.647823 | -0.510160 | -0.967028 | 0.944173 | -0.180504 | -0.054619 | ... | -0.797732 | 0 | 0.595178 | 1 | 1.481133 | 0.000000 | 0.762279 | 1.036245 | -0.241881 | 0.0 |
| 173188 | 0 | 1 | 0 | 0 | -0.647823 | -0.651175 | -0.967028 | -0.954585 | -0.180504 | -0.054619 | ... | -0.797732 | 0 | 0.595178 | 1 | 1.481133 | 0.000000 | 0.762279 | 1.036245 | -0.241881 | 0.0 |
| 173189 | 0 | 1 | 0 | 0 | -0.647823 | -0.651175 | -0.967028 | -0.954585 | -0.180504 | -0.054619 | ... | -0.797732 | 0 | 0.595178 | 1 | 1.481133 | 0.000000 | 0.762279 | 1.036245 | -0.241881 | 0.0 |
5 rows × 49 columns
# Peek at the first rows of the test features (label and id columns already removed).
test_df.head()
| entities.urls | entities.media | user_in_net | has_covid_keyword | tweets_keywords_3_in_degree | tweets_keywords_3_out_degree | tweets_keywords_3_in_strength | tweets_keywords_3_out_strength | tweets_keywords_3_eigenvector_in | tweets_keywords_3_eigenvector_out | ... | users_reply_clustering | user.followers_isna | users_mention_isna | following_users_isna | users_reply_isna | log1p_num_hashtags | log1p_followers_count | log1p_friends_count | log1p_statuses_count | log1p_num_mentioned | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 194715 | 1 | 1 | 1 | 1 | 1.908895 | 2.059462 | 0.938154 | 0.944173 | -0.180504 | -0.054602 | ... | 0.791296 | 0 | 0.595178 | 0 | -0.675159 | 0.693147 | 1.203002 | 0.781949 | 0.956741 | 0.0 |
| 194716 | 1 | 1 | 1 | 0 | -0.106151 | -0.375587 | 1.083082 | 0.944173 | -0.180504 | -0.054618 | ... | 0.791296 | 0 | 0.595178 | 0 | -0.675159 | 0.000000 | 1.203002 | 0.781949 | 0.956741 | 0.0 |
| 194717 | 1 | 1 | 1 | 0 | -0.310304 | -0.247161 | 0.925793 | 0.944173 | -0.180504 | -0.054619 | ... | 0.791296 | 0 | 0.595178 | 0 | -0.675159 | 0.000000 | 1.203002 | 0.781949 | 0.956741 | 0.0 |
| 194718 | 1 | 1 | 1 | 1 | -0.206248 | 1.744481 | 0.925793 | 0.944173 | -0.180504 | -0.054610 | ... | 0.791296 | 0 | 0.595178 | 0 | -0.675159 | 0.693147 | 1.203002 | 0.781949 | 0.956741 | 0.0 |
| 194719 | 1 | 0 | 1 | 0 | -0.647823 | -0.651175 | -0.967028 | -0.954585 | -0.180504 | -0.054619 | ... | 0.791296 | 0 | 0.595178 | 0 | -0.675159 | 0.693147 | 1.203002 | 0.781949 | 0.956741 | 0.0 |
5 rows × 49 columns
# Evaluate hard predictions on the validation split with a per-class report
# (precision / recall / F1 to three decimals).
val_predictions = model.predict(val_df.values)
val_out = classification_report(
    y_true=val_df_labels.values,
    y_pred=val_predictions,
    digits=3,
    output_dict=False,
)
print(val_out)
precision recall f1-score support
0 0.726 0.674 0.699 10954
1 0.634 0.690 0.661 8989
accuracy 0.681 19943
macro avg 0.680 0.682 0.680 19943
weighted avg 0.685 0.681 0.682 19943
[Parallel(n_jobs=24)]: Using backend ThreadingBackend with 24 concurrent workers. [Parallel(n_jobs=24)]: Done 2 tasks | elapsed: 0.0s [Parallel(n_jobs=24)]: Done 152 tasks | elapsed: 0.1s [Parallel(n_jobs=24)]: Done 200 out of 200 | elapsed: 0.1s finished
# Same per-class report for the held-out test split.
test_predictions = model.predict(test_df.values)
test_out = classification_report(
    y_true=test_df_labels.values,
    y_pred=test_predictions,
    digits=3,
    output_dict=False,
)
print(test_out)
precision recall f1-score support
0 0.716 0.640 0.676 10639
1 0.633 0.710 0.669 9305
accuracy 0.673 19944
macro avg 0.675 0.675 0.673 19944
weighted avg 0.677 0.673 0.673 19944
[Parallel(n_jobs=24)]: Using backend ThreadingBackend with 24 concurrent workers. [Parallel(n_jobs=24)]: Done 2 tasks | elapsed: 0.0s [Parallel(n_jobs=24)]: Done 152 tasks | elapsed: 0.1s [Parallel(n_jobs=24)]: Done 200 out of 200 | elapsed: 0.1s finished
# Build a SHAP explainer for the fitted model. shap.Explainer picks an algorithm
# from the model type (presumably the tree explainer for this random forest,
# since explainer.expected_value is used later) -- confirm if the model changes.
explainer = shap.Explainer(model)
# Row count of the full test split, to size the SHAP subsample.
test_df.shape[0]
19944
# Subsample the test split to keep the SHAP computation fast;
# typically 100-1000 examples are enough for stable explanations.
SHAP_SAMPLE_FRAC = 0.05
test_df_sample = test_df.sample(frac=SHAP_SAMPLE_FRAC, random_state=42)
# Align labels by index *label* (explicit .loc) so they match the sampled rows
# regardless of position.
test_df_labels_sample = test_df_labels.loc[test_df_sample.index]
test_df_sample
| entities.urls | entities.media | user_in_net | has_covid_keyword | tweets_keywords_3_in_degree | tweets_keywords_3_out_degree | tweets_keywords_3_in_strength | tweets_keywords_3_out_strength | tweets_keywords_3_eigenvector_in | tweets_keywords_3_eigenvector_out | ... | users_reply_clustering | user.followers_isna | users_mention_isna | following_users_isna | users_reply_isna | log1p_num_hashtags | log1p_followers_count | log1p_friends_count | log1p_statuses_count | log1p_num_mentioned | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 196995 | 0 | 1 | 1 | 0 | -0.647823 | -0.651175 | -0.967028 | -0.954585 | -0.180504 | -0.054619 | ... | 1.341809 | 0 | 0.595178 | 0 | -0.675159 | 0.000000 | 0.001409 | 0.576489 | 0.647097 | 0.0 |
| 208764 | 0 | 1 | 0 | 0 | 0.639640 | 1.776070 | 1.004117 | 1.055509 | -0.180504 | -0.054619 | ... | -0.797732 | 0 | 0.595178 | 1 | 1.481133 | 0.000000 | -0.203655 | -1.429558 | -0.643354 | 0.0 |
| 198537 | 1 | 0 | 1 | 1 | 0.171867 | 1.041619 | 0.925793 | 0.960493 | -0.180504 | -0.054619 | ... | -0.797732 | 0 | -1.680170 | 0 | -0.675159 | 0.000000 | -1.295609 | -0.905254 | -0.341030 | 1.0 |
| 199457 | 0 | 0 | 1 | 0 | -0.647823 | -0.651175 | -0.967028 | -0.954585 | -0.180504 | -0.054619 | ... | 1.103448 | 0 | -1.680170 | 0 | -0.675159 | 0.000000 | -0.436851 | 0.158218 | -1.101759 | 1.0 |
| 204573 | 0 | 0 | 1 | 0 | 2.409977 | 2.188824 | 1.005758 | 1.028141 | -0.180504 | -0.054619 | ... | -0.797732 | 0 | 0.595178 | 0 | -0.675159 | 2.079442 | -0.249862 | 0.215947 | 0.440486 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 213637 | 0 | 1 | 1 | 0 | 1.427123 | 2.108751 | 1.037142 | 1.027553 | -0.180504 | 0.734861 | ... | -0.797732 | 0 | 0.595178 | 0 | -0.675159 | 1.791759 | 0.129038 | 0.559467 | -0.564597 | 0.0 |
| 214221 | 1 | 0 | 1 | 0 | -0.206248 | 1.888641 | 1.174367 | 1.064132 | -0.180504 | -0.054109 | ... | 1.209000 | 0 | 0.595178 | 0 | -0.675159 | 1.609438 | 0.684298 | 1.069249 | 0.523685 | 0.0 |
| 212261 | 0 | 0 | 1 | 0 | 0.257579 | 2.372401 | 0.925793 | 0.990173 | -0.180504 | 0.226447 | ... | 0.029750 | 0 | 0.595178 | 0 | -0.675159 | 0.000000 | -0.090805 | 0.502197 | -0.012185 | 0.0 |
| 213929 | 0 | 0 | 1 | 0 | -0.418476 | -0.510160 | 0.925793 | 0.944173 | -0.180504 | -0.054619 | ... | 1.581774 | 0 | 0.595178 | 0 | -0.675159 | 0.693147 | 0.130952 | -0.651068 | 0.757100 | 0.0 |
| 198449 | 1 | 0 | 0 | 0 | -0.418476 | -0.651175 | 0.925793 | -0.954585 | -0.180504 | -0.054619 | ... | -0.797732 | 0 | 0.595178 | 1 | 1.481133 | 0.000000 | -1.564675 | -0.719433 | -0.964353 | 0.0 |
997 rows × 49 columns
# True labels of the sampled rows, aligned by index with test_df_sample.
test_df_labels_sample
196995 1
208764 1
198537 0
199457 0
204573 0
..
213637 1
214221 1
212261 1
213929 1
198449 1
Name: retweet_label, Length: 997, dtype: int64
# Predicted class probabilities for the sampled rows. Column j is the
# probability of model.classes_[j]; the cells below read column 0 as class 0
# and column 1 as class 1.
test_df_conf_sample = model.predict_proba(test_df_sample.values)
test_df_conf_sample
[Parallel(n_jobs=24)]: Using backend ThreadingBackend with 24 concurrent workers. [Parallel(n_jobs=24)]: Done 2 tasks | elapsed: 0.0s [Parallel(n_jobs=24)]: Done 152 tasks | elapsed: 0.0s [Parallel(n_jobs=24)]: Done 200 out of 200 | elapsed: 0.1s finished
array([[0.5306387 , 0.4693613 ],
[0.39770972, 0.60229028],
[0.77102733, 0.22897267],
...,
[0.40260657, 0.59739343],
[0.58668253, 0.41331747],
[0.70577764, 0.29422236]])
# Compute SHAP explanations for every sampled row. The printed shapes show the
# arrays are indexed (row, feature, class) for values and (row, class) for base values.
shap_values = explainer(test_df_sample)
# One base (expected) value per row and per class.
shap_values.base_values.shape
(997, 2)
# Per-row, per-feature, per-class SHAP contributions.
shap_values.values.shape
(997, 49, 2)
# Feature values the explanations refer to (rows x features, no class axis).
shap_values.data.shape
(997, 49)
def _class_explanation(class_idx):
    """Return a shap.Explanation of the sampled rows restricted to one output class.

    Feature names are passed explicitly, as the clustered bar plot does, so the
    plots label features by column name instead of generic placeholders.
    """
    return shap.Explanation(
        shap_values.values[:, :, class_idx],
        shap_values.base_values[:, class_idx],
        shap_values.data,
        display_data=test_df_sample,
        feature_names=list(test_df_sample.columns),
    )


# Explanation for class 0 on a confident, correct prediction:
# sampled row 2 has P(class 0) = 0.771 in test_df_conf_sample and true label 0.
idx = 2
exp = _class_explanation(0)
shap.plots.waterfall(exp[idx], max_display=30)
# The same example explained for class 1 as a check (symmetric chart).
idx = 2  # the same example
exp = _class_explanation(1)
shap.plots.waterfall(exp[idx], max_display=30)
# Horizontal (force-plot) view of the same row's class-1 explanation.
idx = 2
shap.initjs()  # load shap's JavaScript so the interactive plot renders in the notebook
base_value_class_1 = explainer.expected_value[1]
row_contributions = shap_values.values[idx, :, 1]
row_features = test_df_sample.iloc[idx, :]
shap.force_plot(base_value_class_1, row_contributions, row_features)
# Visualize the class-1 explanations of all sampled rows in one stacked force plot.
# The notebook widget lets you explore different variables interactively.
shap.initjs()
shap.force_plot(explainer.expected_value[1], shap_values.values[:, :, 1], test_df_sample)
# Dot-style summary plot: each feature's influence on class 1 across the sample,
# coloured by the feature's value.
class_1_shap = shap_values.values[:, :, 1]
shap.summary_plot(class_1_shap, features=test_df_sample, plot_type='dot')
# Hierarchical clustering of the features on the sampled rows/labels; used to
# group redundant features in the clustered bar plot.
clust = shap.utils.hclust(test_df_sample, y=test_df_labels_sample, linkage="single")
`early_stopping_rounds` in `fit` method is deprecated for better compatibility with scikit-learn, use `early_stopping_rounds` in constructor or`set_params` instead. No/low signal found from feature 2 (this is typically caused by constant or near-constant features)! Cluster distances can't be computed for it (so setting all distances to 1). No/low signal found from feature 3 (this is typically caused by constant or near-constant features)! Cluster distances can't be computed for it (so setting all distances to 1). 84%|████████▎ | 41/49 [00:11<00:03, 2.35it/s]No/low signal found from feature 40 (this is typically caused by constant or near-constant features)! Cluster distances can't be computed for it (so setting all distances to 1). No/low signal found from feature 41 (this is typically caused by constant or near-constant features)! Cluster distances can't be computed for it (so setting all distances to 1). No/low signal found from feature 42 (this is typically caused by constant or near-constant features)! Cluster distances can't be computed for it (so setting all distances to 1). 100%|██████████| 49/49 [00:13<00:00, 3.03it/s]No/low signal found from feature 48 (this is typically caused by constant or near-constant features)! Cluster distances can't be computed for it (so setting all distances to 1). 50it [00:13, 1.04s/it]
# Global importance for class 1 (mean |SHAP| over the sample), with features
# grouped by the hclust result; clustering_cutoff=1 bounds which merges are drawn.
exp = shap.Explanation(
    shap_values.values[:, :, 1],
    shap_values.base_values[:, 1],
    shap_values.data,
    display_data=test_df_sample,
    feature_names=list(test_df_sample.columns),
)
# Show every feature. The previous hard-coded 48 was one short of the 49
# columns, so the bar plot folded the remainder into an aggregated row.
shap.plots.bar(exp, max_display=len(test_df_sample.columns), clustering=clust, clustering_cutoff=1)

# SHAP dependence plots show the effect of a single feature across the whole
# dataset. They plot a feature's value vs. the SHAP value of that feature across
# many samples. They are similar to partial dependence plots, but account for
# the interaction effects present in the features, and are only defined in
# regions of the input space supported by data. The vertical dispersion of SHAP
# values at a single feature value is driven by interaction effects, and another
# feature is chosen for coloring to highlight possible interactions.
for name in test_df_sample.columns:
    shap.dependence_plot(name, shap_values.values[:, :, 1], test_df_sample, display_features=test_df_sample)